import numpy as np
# import pykalman
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib as mpl
import scipy.stats as stats
import ekf_testv6 as ekf
from sklearn import datasets, linear_model
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
plt.rcParams["font.family"] = "Times New Roman"
mpl.rcParams['mathtext.fontset'] = "cm"
import matplotlib.transforms as mtransforms

# ---------------------------------------------------------------------------
# Figure 1: N-year means of the temperature record, refreshed every `updN`
# years, drawn with their +/- 2 standard-error band and +/- 2 standard-
# deviation error bar, and saved as PDF and PNG.
# ---------------------------------------------------------------------------

# The temperature series comes from the ekf module. The commented-out lines
# below are the earlier in-script loading / offsetting of the CSV, kept for
# reference only.
#data2 = np.genfromtxt(open("HadCRUT5.global.annual.csv", "rb"),dtype=float, delimiter=',')
temps = ekf.temps  #data2[:, 1]
#JonesOffset=13.85
#sdate= 1850
#offset= -np.mean(temps[(1960-sdate):(1990-sdate)]) +JonesOffset + 273.15 #Jones2013 13.7 to 14 (Jones1999)
#temps=temps+offset

plt.rcParams['figure.figsize'] = (5, 4)
fig, ax = plt.subplots()
plt.subplots_adjust(left=0.2, right=0.8)
updN = 10  # spacing (in samples) between successive plotted averages
ekf.plot_boilerplate(ax)
plt.yticks(np.arange(286.2, 288.4, 0.2))

plt.title(str(ekf.N) + '-year Mean, Updated Every ' + str(updN) + ' years')
for idx in np.arange(ekf.fN, (len(temps) - ekf.cN + 1), updN):
    # Index -> calendar year; 1850 matches the `sdate` in the legacy lines
    # above (presumably the first year of the record -- confirm if it changes).
    year = idx + 1850
    window_mean = ekf.moving_aves[idx]
    window_std = ekf.std_aves[idx]
    std_err = window_std / np.sqrt(ekf.N)  # computed once, used for y and height

    # Filled band: mean +/- one standard error, spanning ekf.N years.
    ax.add_patch(Rectangle((year - ekf.cN + 1, window_mean - std_err),
                           ekf.N, std_err * 2,
                           facecolor=ekf.colorstate))
    # Horizontal line at the window mean across the same span.
    plt.plot([year - ekf.cN + 1, year + ekf.fN + 0.9],
             [window_mean, window_mean],
             linewidth=1, color=ekf.colorekf)
    # Error bar of +/- 2 standard deviations at the reference year.
    plt.errorbar(year, window_mean, window_std * 2,
                 capsize=2, color=ekf.coloruncert)

plt.plot(ekf.dates, temps, 'o', label='HadCRUT5 GMST measurements $Y_{t}$',
         markersize=2, color=ekf.colorgrey)

# Legend-only proxy lines. They carry NO data, so they cannot be drawn on the
# axes or influence its data limits (previously they were real segments from
# (0, 0) to (1, 1)).
plt.plot([], [], '-', label='Averages: Standard Climate Normals', color=ekf.colorekf)
plt.plot([], [], '-', label='± Standard Deviation * 2', color=ekf.coloruncert)
plt.plot([], [], '-', label='± Standard Error * 2', color=ekf.colorstate)
plt.legend(loc="best", prop={'size': 8})

# Secondary y-axis: same range shifted by 273.15 and labelled in degrees C.
axb = ax.twinx()
mn, mx = ax.get_ylim()
axb.set_yticks(np.arange(13, 15.2, 0.2))
axb.set_ylim(mn - 273.15, mx - 273.15)
axb.set_ylabel('Temperature (°C)')
fig.savefig("30_yr_averages.pdf", format="pdf")
fig.savefig("30_yr_averages.png", dpi=400, format="png")

# Restrict to the samples selected by the [ekf.fN : -ekf.cN + 1] slice, the
# same start/stop the averaging loop above iterates over.
win_start, win_stop = ekf.fN, -ekf.cN + 1
windowed_means = ekf.moving_aves[win_start:win_stop]

# Residuals of the measurements about the moving average.
diffpts30m = -windowed_means + temps[win_start:win_stop]

# Score between the moving average and the first column of ekf.xblind over
# the same samples.
r2 = ekf.r2_score(windowed_means, ekf.xblind[win_start:win_stop, 0])
print('r2 score for blind vs 30yr-mean is', r2)

# ---------------------------------------------------------------------------
# Figure 2, left half (panels a-d): distribution of the residuals about the
# moving average (`diffpts30m`, computed above).
#   a) histogram with a zero-mean Gaussian overlay
#   b) per-10-sample RMS deviation about the overall mean
#   c), d) per-10-sample 3rd / 4th central moments, scaled by the overall
#          standard deviation (skewness / kurtosis)
# ---------------------------------------------------------------------------
axlabels = ['a)', 'b)', 'c)', 'd)']

fig2 = plt.figure(3, figsize=(12, 5))
plt.subplots_adjust(top=0.855)
grid = plt.GridSpec(3, 4, wspace=0.3, hspace=0.4)
fig2.suptitle("Distribution Metrics of Residuals")

# Overall mean / standard deviation of the residuals, computed once and reused
# for the Gaussian overlay, the sigma markers and the moment normalisation.
aavg = np.mean(diffpts30m)
twTRstd = np.std(diffpts30m)

# --- panel a): histogram + Gaussian with the same standard deviation --------
ax1 = fig2.add_subplot(grid[:, 0])
xnorm = np.linspace(-.5, .5, 100)
ynorm = stats.norm.pdf(xnorm, scale=twTRstd)
ax1.hist(diffpts30m, bins=30, density=True, color='mediumseagreen')
ax1.plot(xnorm, ynorm, color=ekf.pcolor, linewidth=1)

# Thin vertical markers at 0, +/-1, +/-2 and +/-3 standard deviations, each
# reaching up to the Gaussian curve.
for d in range(4):
    dist = twTRstd * d
    pdf_height = stats.norm.pdf(dist, scale=twTRstd)
    ax1.plot([dist, dist], [0, pdf_height], color=ekf.pcolor, linewidth=0.5)
    ax1.plot([-dist, -dist], [0, pdf_height], color=ekf.pcolor, linewidth=0.5)

ax1.set_title("Residuals from 30-yr Mean \n Throughout Timeseries")

# Offsets (in points) of the panel letters relative to each axes' top-left corner.
sy = 0
sx = 10

ax1.set_xlabel("Standard Deviation (K)")
label = axlabels[0]
trans = mtransforms.ScaledTranslation(-(sx + 15) / 72, sy / 72, fig2.dpi_scale_trans)
ax1.text(0.0, 1.0, label, transform=ax1.transAxes + trans, fontsize='large', va='bottom')

label_list = ["", "Skewness", "Kurtosis"]
# Raw strings: without the r-prefix "\b" is a backspace character and
# "\s" / "\m" are invalid escape sequences.
eqslatex = [r" $ \bar{\sigma} = $",
            r"$ \bar{\mu_3} / \bar{\sigma^3} = $",
            r"$ \bar{\mu_4} / \bar{\sigma^4} $"]

# --- panel b): RMS deviation per block of 10 samples -------------------------
ax2 = fig2.add_subplot(grid[0, 1])
# Drop the leading remainder so the residuals split evenly into rows of 10.
resdiffs = diffpts30m[len(diffpts30m) % 10:].reshape((len(diffpts30m) // 10, 10))
# NOTE(review): this date slice is hard-coded; it must yield exactly one date
# per row of `resdiffs` -- re-check against ekf.fN / ekf.cN if the record
# length or window changes.
block_dates = ekf.dates[15:-26:10]
ax2.plot(block_dates, np.sqrt(np.mean((resdiffs - aavg) ** 2, axis=1)), color='mediumseagreen')
ax2.set_title("Std. Deviation", y=0.95)
# Reference levels for panels b, c, d: the overall standard deviation, and the
# skewness (0) and kurtosis (3) of a normal distribution.
hts = [twTRstd, 0, 3]
print("Mean: " + str(aavg))
print("Stdev: " + str(twTRstd))

ax2.plot([ekf.dates[0], ekf.dates[-1]], [hts[0], hts[0]], ls="--", color=ekf.pcolor)
#ax2.set_xticks(np.arange(1850,2150,50))
ax2.tick_params(axis='x', labelsize=8)

label = axlabels[1]
trans = mtransforms.ScaledTranslation(-(sx) / 72, sy / 72, fig2.dpi_scale_trans)
ax2.text(0.0, 1.0, label, transform=ax2.transAxes + trans, fontsize='large', va='bottom')

# --- panels c), d): standardised 3rd and 4th moments per block ---------------
for i in [1, 2]:
    ax2 = fig2.add_subplot(grid[i, 1])
    moment = stats.moment(resdiffs, moment=i + 2, axis=1, nan_policy='omit')
    # Scaled by the OVERALL standard deviation, not each block's own.
    sdmom = np.divide(moment, twTRstd ** (i + 2))
    ax2.plot(block_dates, sdmom, color='mediumseagreen')
    ax2.set_title("Std. " + label_list[i], y=0.95)
    ax2.set_xticks(np.arange(1850, 2150, 50))
    ax2.tick_params(axis='x', labelsize=8)
    ax2.plot([ekf.dates[0], ekf.dates[-1]], [hts[i], hts[i]], color=ekf.pcolor)
    print(label_list[i] + ": " + str(np.mean(sdmom)))
    label = axlabels[i + 1]
    trans = mtransforms.ScaledTranslation(-(sx) / 72, sy / 72, fig2.dpi_scale_trans)
    ax2.text(0.0, 1.0, label, transform=ax2.transAxes + trans, fontsize='large', va='bottom')

# ---------------------------------------------------------------------------
# Figure 2, right half (panels e-h): the same distribution metrics, now for
# the residuals of the measurements about the first column of ekf.xblind
# (titled "Blind EBM" below).
# ---------------------------------------------------------------------------
axlabels = ['e)', 'f)', 'g)', 'h)']
# NOTE: `diffpts30m` is re-bound here; from this point on it holds the
# residuals about ekf.xblind[:, 0], over the full record.
diffpts30m = -ekf.xblind[:, 0] + temps

# Overall mean / standard deviation, computed once and reused below.
aavg = np.mean(diffpts30m)
twTRstd = np.std(diffpts30m)

# --- panel e): histogram + Gaussian with the same standard deviation --------
ax1 = fig2.add_subplot(grid[:, 2])
xnorm = np.linspace(-.5, .5, 100)
ynorm = stats.norm.pdf(xnorm, scale=twTRstd)
ax1.hist(diffpts30m, bins=30, density=True, color='darkgoldenrod')
ax1.plot(xnorm, ynorm, color=ekf.pcolor, linewidth=1)

# Thin vertical markers at 0, +/-1, +/-2 and +/-3 standard deviations.
for d in range(4):
    dist = twTRstd * d
    pdf_height = stats.norm.pdf(dist, scale=twTRstd)
    ax1.plot([dist, dist], [0, pdf_height], color=ekf.pcolor, linewidth=0.5)
    ax1.plot([-dist, -dist], [0, pdf_height], color=ekf.pcolor, linewidth=0.5)

ax1.set_title("Residuals from Blind EBM \n Throughout Timeseries")

# Offsets (in points) of the panel letters relative to each axes' top-left corner.
sy = 0
sx = 10

ax1.set_xlabel("Standard Deviation (K)")
label = axlabels[0]
trans = mtransforms.ScaledTranslation(-(sx + 15) / 72, sy / 72, fig2.dpi_scale_trans)
ax1.text(0.0, 1.0, label, transform=ax1.transAxes + trans, fontsize='large', va='bottom')

# --- panel f): RMS deviation per block of 10 samples -------------------------
ax2 = fig2.add_subplot(grid[0, 3])
# Drop the leading remainder so the series splits evenly into rows of 10.
# This used to be hard-coded as [4:] / (17, 10), i.e. valid only for a
# 174-sample record; deriving it from the length gives the same result there
# and keeps working when the record grows.
n_lead = len(diffpts30m) % 10
resdiffs = diffpts30m[n_lead:].reshape((-1, 10))
# One date per row: the first date of each 10-sample block.
block_dates = ekf.dates[n_lead::10][:len(resdiffs)]
ax2.plot(block_dates, np.sqrt(np.mean((resdiffs - aavg) ** 2, axis=1)), color='darkgoldenrod')
ax2.set_title("Std. Deviation", y=0.95)
# Reference levels for panels f, g, h: the overall standard deviation, and the
# skewness (0) and kurtosis (3) of a normal distribution.
hts = [twTRstd, 0, 3]
print("Mean: " + str(aavg))
print("Stdev: " + str(twTRstd))

ax2.plot([ekf.dates[0], ekf.dates[-1]], [hts[0], hts[0]], ls="--", color=ekf.pcolor)
#ax2.set_xticks(np.arange(1850,2150,50))
ax2.tick_params(axis='x', labelsize=8)
label = axlabels[1]
trans = mtransforms.ScaledTranslation(-(sx) / 72, sy / 72, fig2.dpi_scale_trans)
ax2.text(0.0, 1.0, label, transform=ax2.transAxes + trans, fontsize='large', va='bottom')

# --- panels g), h): standardised 3rd and 4th moments per block ---------------
for i in [1, 2]:
    ax2 = fig2.add_subplot(grid[i, 3])
    moment = stats.moment(resdiffs, moment=i + 2, axis=1, nan_policy='omit')
    # Scaled by the OVERALL standard deviation, not each block's own.
    sdmom = np.divide(moment, twTRstd ** (i + 2))
    ax2.plot(block_dates, sdmom, color='darkgoldenrod')
    ax2.set_title("Std. " + label_list[i], y=0.95)
    ax2.set_xticks(np.arange(1850, 2150, 50))
    ax2.tick_params(axis='x', labelsize=8)
    ax2.plot([ekf.dates[0], ekf.dates[-1]], [hts[i], hts[i]], color=ekf.pcolor)
    print(label_list[i] + ": " + str(np.mean(sdmom)))
    label = axlabels[i + 1]
    trans = mtransforms.ScaledTranslation(-(sx) / 72, sy / 72, fig2.dpi_scale_trans)
    ax2.text(0.0, 1.0, label, transform=ax2.transAxes + trans, fontsize='large', va='bottom')
    
# Write the residual-distribution figure in both output formats (vector PDF
# first, then a 400-dpi PNG), then open the interactive windows.
for out_path, save_kwargs in (
    ("dist_residuals.pdf", {"format": "pdf"}),
    ("dist_residuals.png", {"dpi": 400, "format": "png"}),
):
    fig2.savefig(out_path, **save_kwargs)

plt.show()
